{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TensorFlow Dataset API\n",
    "\n",
    "**Learning Objectives**\n",
    "1. Learn how to use tf.data to read data from memory\n",
    "1. Learn how to use tf.data in a training loop\n",
    "1. Learn how to use tf.data to read data from disk\n",
    "1. Learn how to write production input pipelines with feature engineering (batching, shuffling, etc.)\n",
    "\n",
    "\n",
    "In this notebook, we will start by refactoring the linear regression we implemented in the previous lab so that it takes data from a `tf.data.Dataset`, and we will learn how to implement **stochastic gradient descent** with it. In this case, the original dataset will be synthetic and read by the `tf.data` API directly from memory.\n",
    "\n",
    "In the second part, we will learn how to load a dataset with the `tf.data` API when the dataset resides on disk.\n",
    "\n",
    "Each learning objective corresponds to a __#TODO__ in the [student lab notebook](../labs/2_dataset_api.ipynb) -- try to complete that notebook before reviewing this solution notebook.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.8.1\n"
     ]
    }
   ],
   "source": [
    "# The json module encodes Python objects as JSON strings and decodes JSON back into Python objects\n",
    "import json\n",
    "# The math module provides mathematical functions\n",
    "import math\n",
    "# The os module provides functions for interacting with the operating system\n",
    "import os\n",
    "# The pprint module pretty-prints arbitrary Python data structures in a readable form\n",
    "from pprint import pprint\n",
    "\n",
    "# Here we'll import data processing libraries like numpy and tensorflow\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "# Here we'll show the currently installed version of TensorFlow\n",
    "print(tf.version.VERSION)\n",
    "\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading data from memory"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating the dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's consider the synthetic dataset of the previous section:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_POINTS = 10\n",
    "# tf.constant() creates a constant tensor from a tensor-like object.\n",
    "X = tf.constant(range(N_POINTS), dtype=tf.float32)\n",
    "Y = 2 * X + 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We begin by implementing a function that takes as input\n",
    "\n",
    "\n",
    "- our $X$ and $Y$ vectors of synthetic data generated by the linear function $y = 2x + 10$\n",
    "- the number of passes over the dataset we want to train on (`epochs`)\n",
    "- the size of the batches the dataset is split into (`batch_size`)\n",
    "\n",
    "and returns a `tf.data.Dataset`:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Remark:** Note that the last batch may contain fewer elements than you specified if the dataset is exhausted before the batch is full.\n",
    "\n",
    "If we want every batch to contain exactly the same number of elements, we have to discard the last incomplete batch by\n",
    "setting:\n",
    "\n",
    "```python\n",
    "dataset = dataset.batch(batch_size, drop_remainder=True)\n",
    "```\n",
    "\n",
    "We will do that here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's define the create_dataset() function\n",
    "# TODO 1\n",
    "def create_dataset(X, Y, epochs, batch_size):\n",
    "    # tf.data.Dataset.from_tensor_slices() slices the tensors along their first dimension,\n",
    "    # so each element of the dataset is one (x, y) pair\n",
    "    dataset = tf.data.Dataset.from_tensor_slices((X, Y))\n",
    "    dataset = dataset.repeat(epochs).batch(batch_size, drop_remainder=True)\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's test our function by iterating twice over our dataset in batches of 3 datapoints:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x: [0. 1. 2.] y: [10. 12. 14.]\n",
      "x: [3. 4. 5.] y: [16. 18. 20.]\n",
      "x: [6. 7. 8.] y: [22. 24. 26.]\n",
      "x: [9. 0. 1.] y: [28. 10. 12.]\n",
      "x: [2. 3. 4.] y: [14. 16. 18.]\n",
      "x: [5. 6. 7.] y: [20. 22. 24.]\n"
     ]
    }
   ],
   "source": [
    "BATCH_SIZE = 3\n",
    "EPOCH = 2\n",
    "\n",
    "dataset = create_dataset(X, Y, epochs=EPOCH, batch_size=BATCH_SIZE)\n",
    "\n",
    "for i, (x, y) in enumerate(dataset):\n",
    "    # You can convert a TF tensor to a NumPy array using the .numpy() method\n",
    "    # Let's output the values of `x` and `y`\n",
    "    print(\"x:\", x.numpy(), \"y:\", y.numpy())\n",
    "    assert len(x) == BATCH_SIZE\n",
    "    assert len(y) == BATCH_SIZE"
   ]
  },
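  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `repeat` is applied before `batch`, the batches are cut from one continuous stream of elements. A batch can therefore straddle an epoch boundary (see `x: [9. 0. 1.]` above), and only the final incomplete batch is dropped. The following is a minimal plain-Python sketch of that behavior, not TensorFlow's implementation; the helper name `repeat_then_batch` is hypothetical:\n",
    "\n",
    "```python\n",
    "def repeat_then_batch(items, epochs, batch_size):\n",
    "    # Emulate dataset.repeat(epochs): concatenate `epochs` copies of the data.\n",
    "    stream = list(items) * epochs\n",
    "    # Emulate .batch(batch_size, drop_remainder=True): keep full batches only.\n",
    "    n_full = len(stream) // batch_size\n",
    "    return [stream[i * batch_size:(i + 1) * batch_size] for i in range(n_full)]\n",
    "\n",
    "\n",
    "batches = repeat_then_batch(range(10), epochs=2, batch_size=3)\n",
    "for batch in batches:\n",
    "    print(batch)\n",
    "```\n",
    "\n",
    "With 10 points, 2 epochs and batches of 3, the stream has 20 elements, so the sketch yields 6 full batches and drops the last 2 elements, which matches the 6 lines printed above."
   ]
  },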
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loss function and gradients"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loss function and the function that computes the gradients are the same as before:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loss_mse() returns the mean squared error between the predictions w0 * X + w1 and the targets Y\n",
    "def loss_mse(X, Y, w0, w1):\n",
    "    Y_hat = w0 * X + w1\n",
    "    errors = (Y_hat - Y)**2\n",
    "    return tf.reduce_mean(errors)\n",
    "\n",
    "\n",
    "# compute_gradients() records the loss computation with tf.GradientTape and returns\n",
    "# the gradients of the loss with respect to w0 and w1\n",
    "def compute_gradients(X, Y, w0, w1):\n",
    "    with tf.GradientTape() as tape:\n",
    "        loss = loss_mse(X, Y, w0, w1)\n",
    "    return tape.gradient(loss, [w0, w1])"
   ]
  },
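  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the gradients that `tf.GradientTape` computes here can also be derived by hand. The symbols $N$ (batch size), $x_i$ and $y_i$ (the points in the batch) are notation introduced only for this derivation. The loss implemented in `loss_mse` is\n",
    "\n",
    "$$L(w_0, w_1) = \\frac{1}{N}\\sum_{i=1}^{N} (w_0 x_i + w_1 - y_i)^2$$\n",
    "\n",
    "and differentiating with respect to each weight gives\n",
    "\n",
    "$$\\frac{\\partial L}{\\partial w_0} = \\frac{2}{N}\\sum_{i=1}^{N} (w_0 x_i + w_1 - y_i) x_i, \\qquad \\frac{\\partial L}{\\partial w_1} = \\frac{2}{N}\\sum_{i=1}^{N} (w_0 x_i + w_1 - y_i)$$\n",
    "\n",
    "These are the two values that `compute_gradients` returns, in the order `[w0, w1]`."
   ]
  },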
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training loop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The main difference is that, in the training loop, we now iterate directly on the `tf.data.Dataset` generated by our `create_dataset` function.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "STEP 0 - loss: 109.76800537109375, w0: 0.23999999463558197, w1: 0.4399999976158142\n",
      "\n",
      "STEP 100 - loss: 9.363959312438965, w0: 2.55655837059021, w1: 6.674341678619385\n",
      "\n",
      "STEP 200 - loss: 1.393267273902893, w0: 2.2146825790405273, w1: 8.717182159423828\n",
      "\n",
      "STEP 300 - loss: 0.20730558037757874, w0: 2.082810878753662, w1: 9.505172729492188\n",
      "\n",
      "STEP 400 - loss: 0.03084510937333107, w0: 2.03194260597229, w1: 9.809128761291504\n",
      "\n",
      "STEP 500 - loss: 0.004589457996189594, w0: 2.012321710586548, w1: 9.926374435424805\n",
      "\n",
      "STEP 600 - loss: 0.0006827632314525545, w0: 2.0047526359558105, w1: 9.971602439880371\n",
      "\n",
      "STEP 700 - loss: 0.00010164897685172036, w0: 2.0018346309661865, w1: 9.989042282104492\n",
      "\n",
      "STEP 800 - loss: 1.5142451957217418e-05, w0: 2.000706911087036, w1: 9.995771408081055\n",
      "\n",
      "STEP 900 - loss: 2.256260358990403e-06, w0: 2.0002737045288086, w1: 9.998367309570312\n",
      "\n",
      "STEP 1000 - loss: 3.3405058275093324e-07, w0: 2.000105381011963, w1: 9.999371528625488\n",
      "\n",
      "STEP 1100 - loss: 4.977664502803236e-08, w0: 2.000040054321289, w1: 9.999757766723633\n",
      "\n",
      "STEP 1200 - loss: 6.475602276623249e-09, w0: 2.0000154972076416, w1: 9.99991226196289\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Here we will configure the dataset so that it iterates 250 times over our synthetic dataset in batches of 2.\n",
    "# TODO 2\n",
    "EPOCHS = 250\n",
    "BATCH_SIZE = 2\n",
    "LEARNING_RATE = .02\n",
    "\n",
    "MSG = \"STEP {step} - loss: {loss}, w0: {w0}, w1: {w1}\\n\"\n",
    "\n",
    "w0 = tf.Variable(0.0)\n",
    "w1 = tf.Variable(0.0)\n",
    "\n",
    "dataset = create_dataset(X, Y, epochs=EPOCHS, batch_size=BATCH_SIZE)\n",
    "\n",
    "for step, (X_batch, Y_batch) in enumerate(dataset):\n",
    "\n",
    "    dw0, dw1 = compute_gradients(X_batch, Y_batch, w0, w1)\n",
    "    w0.assign_sub(dw0 * LEARNING_RATE)\n",
    "    w1.assign_sub(dw1 * LEARNING_RATE)\n",
    "\n",
    "    if step % 100 == 0:\n",
    "        loss = loss_mse(X_batch, Y_batch, w0, w1)\n",
    "        print(MSG.format(step=step, loss=loss, w0=w0.numpy(), w1=w1.numpy()))\n",
    "        \n",
    "assert loss < 0.0001\n",
    "assert abs(w0 - 2) < 0.001\n",
    "assert abs(w1 - 10) < 0.001"
   ]
  },
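  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, the first update of the loop above can be reproduced without TensorFlow. The sketch below is plain Python and assumes the analytic gradients of the mean squared error, $(2/N)\\sum_i r_i x_i$ for `w0` and $(2/N)\\sum_i r_i$ for `w1`, where $r_i = w_0 x_i + w_1 - y_i$; the lowercase variable names are chosen here for illustration:\n",
    "\n",
    "```python\n",
    "# First batch produced by the dataset when batch_size is 2.\n",
    "x_batch = [0.0, 1.0]\n",
    "y_batch = [10.0, 12.0]\n",
    "learning_rate = 0.02\n",
    "w0, w1 = 0.0, 0.0\n",
    "\n",
    "n = len(x_batch)\n",
    "residuals = [w0 * x + w1 - y for x, y in zip(x_batch, y_batch)]\n",
    "# Analytic gradients of the mean squared error.\n",
    "dw0 = 2 / n * sum(r * x for r, x in zip(residuals, x_batch))\n",
    "dw1 = 2 / n * sum(residuals)\n",
    "\n",
    "# Gradient descent update, the counterpart of w.assign_sub(dw * LEARNING_RATE).\n",
    "w0 -= learning_rate * dw0\n",
    "w1 -= learning_rate * dw1\n",
    "\n",
    "# Loss on the same batch after the update.\n",
    "loss = sum((w0 * x + w1 - y) ** 2 for x, y in zip(x_batch, y_batch)) / n\n",
    "print(w0, w1, loss)\n",
    "```\n",
    "\n",
    "The result agrees with the `STEP 0` line printed above (`w0` close to 0.24, `w1` close to 0.44, loss close to 109.768), up to float32 rounding. Note also that 10 points repeated 250 times and batched by 2 give 1250 steps, which is why the last line printed is `STEP 1200`."
   ]
  },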
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading data from disk"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Locating the CSV files\n",
    "\n",
    "We will start with the **taxifare dataset** CSV files that we wrote out in a previous lab.\n",
    "\n",
    "The taxifare dataset files have been saved into `../toy_data`.\n",
    "\n",
    "Check that this is the case in the cell below and, if not, regenerate the taxifare CSV files by rerunning that lab.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-rw-r--r-- 1 jupyter jupyter  61473 Nov  7 21:59 ../toy_data/taxi-test.csv\n",
      "-rw-r--r-- 1 jupyter jupyter 288831 Nov  7 21:59 ../toy_data/taxi-train.csv\n",
      "-rw-r--r-- 1 jupyter jupyter  61082 Nov  7 21:59 ../toy_data/taxi-valid.csv\n"
     ]
    }
   ],
   "source": [
    "# ls lists the files that match the given pattern.\n",
    "# The -l flag uses the long format (permissions, owner, size, modification time)\n",
    "!ls -l ../toy_data/taxi*.csv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use tf.data to read the CSV files\n",
    "\n",
    "The `tf.data` API can easily read CSV files using the helper function `tf.data.experimental.make_csv_dataset`.\n",
    "\n",
    "If you have TFRecords (which is recommended), you may use `tf.data.experimental.make_batched_features_dataset` instead."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first step is to define\n",
    "\n",
    "- the feature names in a list `CSV_COLUMNS`\n",
    "- their default values in a list `DEFAULTS`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defining the feature names into a list `CSV_COLUMNS`\n",
    "CSV_COLUMNS = [\n",
    "    'fare_amount',\n",
    "    'pickup_datetime',\n",
    "    'pickup_longitude',\n",
    "    'pickup_latitude',\n",
    "    'dropoff_longitude',\n",
    "    'dropoff_latitude',\n",
    "    'passenger_count',\n",
    "    'key'\n",
    "]\n",
    "LABEL_COLUMN = 'fare_amount'\n",
    "# Defining the default values into a list `DEFAULTS`\n",
    "DEFAULTS = [[0.0], ['na'], [0.0], [0.0], [0.0], [0.0], [0.0], ['na']]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's now wrap the call to `make_csv_dataset` in its own function that takes only the file pattern (i.e. glob) matching the dataset files:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<PrefetchDataset element_spec=OrderedDict([('fare_amount', TensorSpec(shape=(1,), dtype=tf.float32, name=None)), ('pickup_datetime', TensorSpec(shape=(1,), dtype=tf.string, name=None)), ('pickup_longitude', TensorSpec(shape=(1,), dtype=tf.float32, name=None)), ('pickup_latitude', TensorSpec(shape=(1,), dtype=tf.float32, name=None)), ('dropoff_longitude', TensorSpec(shape=(1,), dtype=tf.float32, name=None)), ('dropoff_latitude', TensorSpec(shape=(1,), dtype=tf.float32, name=None)), ('passenger_count', TensorSpec(shape=(1,), dtype=tf.float32, name=None)), ('key', TensorSpec(shape=(1,), dtype=tf.string, name=None))])>\n"
     ]
    }
   ],
   "source": [
    "# TODO 3\n",
    "def create_dataset(pattern):\n",
    "    # tf.data.experimental.make_csv_dataset() reads CSV files into a dataset;\n",
    "    # the second argument is the batch size\n",
    "    return tf.data.experimental.make_csv_dataset(\n",
    "        pattern, 1, CSV_COLUMNS, DEFAULTS)\n",
    "\n",
    "\n",
    "tempds = create_dataset('../toy_data/taxi-train*')\n",
    "# Let's output the value of `tempds`\n",
    "print(tempds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that this is a prefetched dataset, where each element is an `OrderedDict` whose keys are the feature names and whose values are tensors of shape `(1,)` (i.e. vectors of length 1, because we set the batch size to 1).\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'dropoff_latitude': array([40.717358], dtype=float32),\n",
      " 'dropoff_longitude': array([-73.9839], dtype=float32),\n",
      " 'fare_amount': array([13.5], dtype=float32),\n",
      " 'key': array([b'132'], dtype=object),\n",
      " 'passenger_count': array([1.], dtype=float32),\n",
      " 'pickup_datetime': array([b'2012-12-16 20:38:58 UTC'], dtype=object),\n",
      " 'pickup_latitude': array([40.758236], dtype=float32),\n",
      " 'pickup_longitude': array([-73.97846], dtype=float32)}\n",
      "\n",
      "\n",
      "{'dropoff_latitude': array([40.7267], dtype=float32),\n",
      " 'dropoff_longitude': array([-73.989174], dtype=float32),\n",
      " 'fare_amount': array([7.], dtype=float32),\n",
      " 'key': array([b'107'], dtype=object),\n",
      " 'passenger_count': array([1.], dtype=float32),\n",
      " 'pickup_datetime': array([b'2015-02-07 08:21:20 UTC'], dtype=object),\n",
      " 'pickup_latitude': array([40.742184], dtype=float32),\n",
      " 'pickup_longitude': array([-73.97483], dtype=float32)}\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Let's iterate over the first two elements of this dataset using `dataset.take(2)`,\n",
    "# and convert them to ordinary Python dictionaries with NumPy arrays as values for readability:\n",
    "for data in tempds.take(2):\n",
    "    pprint({k: v.numpy() for k, v in data.items()})\n",
    "    print(\"\\n\")"
   ]
  },
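  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the element structure concrete, here is a plain-Python sketch that groups CSV rows into a column-oriented batch, the same layout as the dictionaries printed above. It uses only the standard `csv` module and is not how `tf.data` is implemented. The helper name `read_batch` is hypothetical, the two inline rows are copied from the output above, and the sketch assumes the text has no header row:\n",
    "\n",
    "```python\n",
    "import csv\n",
    "import io\n",
    "\n",
    "CSV_COLUMNS = ['fare_amount', 'pickup_datetime', 'pickup_longitude',\n",
    "               'pickup_latitude', 'dropoff_longitude', 'dropoff_latitude',\n",
    "               'passenger_count', 'key']\n",
    "DEFAULTS = [[0.0], ['na'], [0.0], [0.0], [0.0], [0.0], [0.0], ['na']]\n",
    "\n",
    "RAW = '''13.5,2012-12-16 20:38:58 UTC,-73.97846,40.758236,-73.9839,40.717358,1,132\n",
    "7.0,2015-02-07 08:21:20 UTC,-73.97483,40.742184,-73.989174,40.7267,1,107\n",
    "'''\n",
    "\n",
    "\n",
    "def read_batch(text):\n",
    "    # One list per column, so a batch is a dictionary of columns.\n",
    "    batch = {name: [] for name in CSV_COLUMNS}\n",
    "    for row in csv.reader(io.StringIO(text)):\n",
    "        for name, default, value in zip(CSV_COLUMNS, DEFAULTS, row):\n",
    "            if value == '':\n",
    "                # Empty fields fall back to the default value.\n",
    "                batch[name].append(default[0])\n",
    "            else:\n",
    "                # The type of the default decides how the field is parsed.\n",
    "                batch[name].append(type(default[0])(value))\n",
    "    return batch\n",
    "\n",
    "\n",
    "batch = read_batch(RAW)\n",
    "print(batch['fare_amount'], batch['key'])\n",
    "```\n",
    "\n",
    "Each key maps to a list with one entry per row, just as each tensor above has one entry per example in the batch."
   ]
  },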
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transforming the features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What we really need is a dictionary of features + a label. So, we have to do two things to the above dictionary:\n",
    "\n",
    "1. Remove the unwanted columns `key` and `pickup_datetime`\n",
    "1. Keep the label separate from the features\n",
    "\n",
    "Let's first implement a function that takes as input a row (represented as an `OrderedDict` in our `tf.data.Dataset` as above) and then returns a tuple with two elements:\n",
    "\n",
    "* The first element being the same `OrderedDict` with the label and the unwanted columns dropped\n",
    "* The second element being the label itself (`fare_amount`)\n",
    "\n",
    "Neither `key` nor `pickup_datetime` is used as a feature in this lab, which is why both are removed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "UNWANTED_COLS = ['pickup_datetime', 'key']\n",
    "\n",
    "# Let's define the features_and_labels() function\n",
    "# TODO 4a\n",
    "def features_and_labels(row_data):\n",
    "    # dict.pop() removes the given key from the dictionary and returns its value\n",
    "    label = row_data.pop(LABEL_COLUMN)\n",
    "    features = row_data\n",
    "\n",
    "    for unwanted_col in UNWANTED_COLS:\n",
    "        features.pop(unwanted_col)\n",
    "\n",
    "    return features, label"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's iterate over 2 examples from our `tempds` dataset and apply our `features_and_labels`\n",
    "function to each of the examples to make sure it's working:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OrderedDict([('pickup_longitude',\n",
      "              <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-73.9823], dtype=float32)>),\n",
      "             ('pickup_latitude',\n",
      "              <tf.Tensor: shape=(1,), dtype=float32, numpy=array([40.772713], dtype=float32)>),\n",
      "             ('dropoff_longitude',\n",
      "              <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-73.96748], dtype=float32)>),\n",
      "             ('dropoff_latitude',\n",
      "              <tf.Tensor: shape=(1,), dtype=float32, numpy=array([40.762695], dtype=float32)>),\n",
      "             ('passenger_count',\n",
      "              <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>)])\n",
      "tf.Tensor([8.5], shape=(1,), dtype=float32) \n",
      "\n",
      "OrderedDict([('pickup_longitude',\n",
      "              <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-73.98736], dtype=float32)>),\n",
      "             ('pickup_latitude',\n",
      "              <tf.Tensor: shape=(1,), dtype=float32, numpy=array([40.760174], dtype=float32)>),\n",
      "             ('dropoff_longitude',\n",
      "              <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-74.00966], dtype=float32)>),\n",
      "             ('dropoff_latitude',\n",
      "              <tf.Tensor: shape=(1,), dtype=float32, numpy=array([40.705006], dtype=float32)>),\n",
      "             ('passenger_count',\n",
      "              <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>)])\n",
      "tf.Tensor([28.], shape=(1,), dtype=float32) \n",
      "\n"
     ]
    }
   ],
   "source": [
    "for row_data in tempds.take(2):\n",
    "    features, label = features_and_labels(row_data)\n",
    "    pprint(features)\n",
    "    print(label, \"\\n\")\n",
    "    assert UNWANTED_COLS[0] not in features.keys()\n",
    "    assert UNWANTED_COLS[1] not in features.keys()\n",
    "    assert label.shape == [1]"
   ]
  },
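  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `features_and_labels` pops entries from `row_data` itself, so the dictionary passed in is modified. That is harmless in this notebook, since each row is used only once, but a variant that copies first may be preferable elsewhere. A minimal sketch on a plain dictionary follows; the name `split_row` is hypothetical, the numeric values are copied from the first example above, and the `'na'` strings are placeholders:\n",
    "\n",
    "```python\n",
    "LABEL_COLUMN = 'fare_amount'\n",
    "UNWANTED_COLS = ['pickup_datetime', 'key']\n",
    "\n",
    "\n",
    "def split_row(row):\n",
    "    # Copy first so that the caller's dictionary is left untouched.\n",
    "    features = dict(row)\n",
    "    label = features.pop(LABEL_COLUMN)\n",
    "    for col in UNWANTED_COLS:\n",
    "        features.pop(col)\n",
    "    return features, label\n",
    "\n",
    "\n",
    "row = {'fare_amount': 8.5, 'pickup_datetime': 'na',\n",
    "       'pickup_longitude': -73.9823, 'pickup_latitude': 40.772713,\n",
    "       'dropoff_longitude': -73.96748, 'dropoff_latitude': 40.762695,\n",
    "       'passenger_count': 1.0, 'key': 'na'}\n",
    "features, label = split_row(row)\n",
    "print(sorted(features), label)\n",
    "```\n",
    "\n",
    "After the call, `row` still contains all eight columns, while `features` holds only the five columns used as inputs."
   ]
  },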
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Batching"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's now refactor our `create_dataset` function so that it takes an additional argument `batch_size` and batches the data accordingly. We will also use the `features_and_labels` function we implemented to turn each element of the dataset into a tuple of features and labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's define the create_dataset() function\n",
    "# TODO 4b\n",
    "def create_dataset(pattern, batch_size):\n",
    "    # tf.data.experimental.make_csv_dataset() reads CSV files into a batched dataset\n",
    "    dataset = tf.data.experimental.make_csv_dataset(\n",
    "        pattern, batch_size, CSV_COLUMNS, DEFAULTS)\n",
    "    return dataset.map(features_and_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's test that our batches are of the right size:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'dropoff_latitude': array([40.762814, 40.74584 ], dtype=float32),\n",
      " 'dropoff_longitude': array([-73.9767  , -74.001915], dtype=float32),\n",
      " 'passenger_count': array([1., 1.], dtype=float32),\n",
      " 'pickup_latitude': array([40.766552, 40.75046 ], dtype=float32),\n",
      " 'pickup_longitude': array([-73.969025, -73.979   ], dtype=float32)}\n",
      "[ 6. 13.] \n",
      "\n",
      "{'dropoff_latitude': array([40.741245, 40.772743], dtype=float32),\n",
      " 'dropoff_longitude': array([-73.98776 , -73.967834], dtype=float32),\n",
      " 'passenger_count': array([1., 1.], dtype=float32),\n",
      " 'pickup_latitude': array([40.74568, 40.82851], dtype=float32),\n",
      " 'pickup_longitude': array([-74.00553 , -73.948814], dtype=float32)}\n",
      "[ 5.3 25. ] \n",
      "\n"
     ]
    }
   ],
   "source": [
    "BATCH_SIZE = 2\n",
    "\n",
    "tempds = create_dataset('../toy_data/taxi-train*', batch_size=BATCH_SIZE)\n",
    "\n",
    "for X_batch, Y_batch in tempds.take(2):\n",
    "    pprint({k: v.numpy() for k, v in X_batch.items()})\n",
    "    print(Y_batch.numpy(), \"\\n\")\n",
    "    assert len(Y_batch) == BATCH_SIZE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Shuffling\n",
    "\n",
    "When training a deep learning model in batches over multiple workers, it is helpful if we shuffle the data. That way, different workers will be working on different parts of the input file at the same time, and so averaging gradients across workers will help. Also, during training, we will need to read the data indefinitely."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's refactor our `create_dataset` function so that it shuffles the data when the dataset is used for training.\n",
    "\n",
    "We will introduce an additional argument `mode` to our function to allow the function body to distinguish the case\n",
    "when it needs to shuffle the data (`mode == \"train\"`) from when it shouldn't (`mode == \"eval\"`).\n",
    "\n",
    "Also, before returning, we will prefetch one element (here, one batch) ahead of time with `dataset.prefetch(1)` to speed up training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO 4c\n",
    "def create_dataset(pattern, batch_size=1, mode=\"eval\"):\n",
    "    # tf.data.experimental.make_csv_dataset() reads CSV files into a batched dataset.\n",
    "    # Note: by default it also shuffles the rows and repeats indefinitely\n",
    "    # (shuffle=True, num_epochs=None).\n",
    "    dataset = tf.data.experimental.make_csv_dataset(\n",
    "        pattern, batch_size, CSV_COLUMNS, DEFAULTS)\n",
    "\n",
    "    # Dataset.map() applies features_and_labels to every element of the dataset;\n",
    "    # cache() keeps the resulting elements in memory once they have been read.\n",
    "    dataset = dataset.map(features_and_labels).cache()\n",
    "\n",
    "    if mode == \"train\":\n",
    "        dataset = dataset.shuffle(1000).repeat()\n",
    "\n",
    "    # Prefetch 1 batch so that input preparation overlaps with training;\n",
    "    # pass tf.data.AUTOTUNE instead of 1 to let tf.data tune the buffer size.\n",
    "    dataset = dataset.prefetch(1)\n",
    "    return dataset"
   ]
  },
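  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`dataset.shuffle(1000)` does not shuffle the whole dataset at once. As described in the `tf.data` documentation, it fills a buffer with 1000 elements (here, batches) and repeatedly emits a randomly chosen element from that buffer, replacing it with the next input element. The following is a plain-Python sketch of that idea, not TensorFlow's implementation; `buffered_shuffle` and its `seed` parameter are hypothetical names:\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def buffered_shuffle(items, buffer_size, seed=0):\n",
    "    rng = random.Random(seed)\n",
    "    buffer = []\n",
    "    for item in items:\n",
    "        if len(buffer) < buffer_size:\n",
    "            # Fill the buffer before emitting anything.\n",
    "            buffer.append(item)\n",
    "            continue\n",
    "        # Emit a random element and put the new item in its slot.\n",
    "        idx = rng.randrange(buffer_size)\n",
    "        yield buffer[idx]\n",
    "        buffer[idx] = item\n",
    "    # The input is exhausted: emit what is left in random order.\n",
    "    rng.shuffle(buffer)\n",
    "    yield from buffer\n",
    "\n",
    "\n",
    "shuffled = list(buffered_shuffle(range(10), buffer_size=3))\n",
    "print(shuffled)\n",
    "```\n",
    "\n",
    "With a buffer smaller than the dataset, an element can move at most `buffer_size - 1` positions toward the front, so the shuffle is only local. A buffer at least as large as the dataset is needed for a uniform shuffle."
   ]
  },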
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's check that our function works well in both modes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(OrderedDict([('pickup_longitude', <tf.Tensor: shape=(2,), dtype=float32, numpy=array([-74.00861, -73.98522], dtype=float32)>), ('pickup_latitude', <tf.Tensor: shape=(2,), dtype=float32, numpy=array([40.71599 , 40.718967], dtype=float32)>), ('dropoff_longitude', <tf.Tensor: shape=(2,), dtype=float32, numpy=array([-73.98326, -73.98102], dtype=float32)>), ('dropoff_latitude', <tf.Tensor: shape=(2,), dtype=float32, numpy=array([40.695023, 40.68868 ], dtype=float32)>), ('passenger_count', <tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 1.], dtype=float32)>)]), <tf.Tensor: shape=(2,), dtype=float32, numpy=array([ 9.3, 10.1], dtype=float32)>)]\n"
     ]
    }
   ],
   "source": [
    "tempds = create_dataset('../toy_data/taxi-train*', 2, \"train\")\n",
    "print(list(tempds.take(1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(OrderedDict([('pickup_longitude', <tf.Tensor: shape=(2,), dtype=float32, numpy=array([-73.998535, -73.99426 ], dtype=float32)>), ('pickup_latitude', <tf.Tensor: shape=(2,), dtype=float32, numpy=array([40.760765, 40.745285], dtype=float32)>), ('dropoff_longitude', <tf.Tensor: shape=(2,), dtype=float32, numpy=array([-73.983765, -74.00697 ], dtype=float32)>), ('dropoff_latitude', <tf.Tensor: shape=(2,), dtype=float32, numpy=array([40.74651, 40.72942], dtype=float32)>), ('passenger_count', <tf.Tensor: shape=(2,), dtype=float32, numpy=array([4., 1.], dtype=float32)>)]), <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6.9, 6.5], dtype=float32)>)]\n"
     ]
    }
   ],
   "source": [
    "tempds = create_dataset('../toy_data/taxi-valid*', 2, \"eval\")\n",
    "print(list(tempds.take(1)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2021 Google Inc.\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
    "http://www.apache.org/licenses/LICENSE-2.0\n",
    "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
   ]
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "python3",
   "name": "tf2-gpu.2-8.m93",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-8:m93"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
